## IMPORT

from pyspark.context import SparkContext, SparkConf
from pyspark.sql import SparkSession, Window, types
from pyspark.sql.types import StructType, IntegerType
from pyspark.sql.context import SQLContext
import pyspark.sql.functions as f

from math import floor, ceil
import sys, json
import pandas as pd

from splink.spark.spark_linker import SparkLinker
from splink.spark.spark_comparison_library import jaro_winkler_at_thresholds, levenshtein_at_thresholds, exact_match
from splink.charts import save_offline_chart

## SPARK SESSION

# Reuse an existing SparkContext if one is running (otherwise create one with a
# default SparkConf) and wrap it in a SparkSession.
sc = SparkContext.getOrCreate(conf=SparkConf())
spark = SparkSession(sc)

# `spark.sparkContext` is `sc`, so configure the context directly.
sc.setLogLevel('WARN')
sc.setCheckpointDir("splink_sandbox/temp_graphframes/")

# Expose the JVM Jaro-Winkler similarity class to Spark SQL as `jaro_winkler`
# (returns a double). NOTE(review): the jar providing this class must be on the
# Spark classpath; it is not configured in this file.
spark.udf.registerJavaFunction(
    'jaro_winkler',
    'uk.gov.moj.dash.linkage.JaroWinklerSimilarity',
    types.DoubleType(),
)

## LOAD DATA FROM SLURM PARAMS

# Positional arguments (per the section header, supplied by the SLURM job):
#   argv[1] -- 0-based index into STATES
#   argv[2] -- contribution year
# STATES holds the 50 state codes plus DC.
STATES = [
    "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "DC", "FL", "GA", "HI",
    "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD", "MA", "MI", "MN", "MS", "MO",
    "MT", "NE", "NV", "NH", "NJ", "NM", "NY", "NC", "ND", "OH", "OK", "OR", "PA", "RI",
    "SC", "SD", "TN", "TX", "UT", "VT", "VA", "WA", "WV", "WI", "WY",
]

if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} STATE_INDEX YEAR")

state_index = int(sys.argv[1])
# Reject out-of-range indices explicitly: a negative index would otherwise
# silently wrap around and select a state from the end of the list.
if not 0 <= state_index < len(STATES):
    sys.exit(f"STATE_INDEX must be between 0 and {len(STATES) - 1}, got {state_index}")

# Lower-cased; used below in SQL filters and in output paths.
state = STATES[state_index].lower()

year = int(sys.argv[2])

print(f"Working in {state} {year}...")

# FEC contributors with a non-null address who gave in `year`, renamed to the
# column names shared with the L2 table. This is NOT restricted to `state`;
# the state filter is applied later for the training subset only.
cf = (
    spark.read.parquet("data/fec_contributors")
    .filter(f"contbr_addr IS NOT NULL AND gave_{year} = 1")
    .withColumnRenamed('cluster', 'uid')
    .withColumnRenamed('contbr_nm_first', 'first_only')
    .withColumnRenamed('contbr_nm_last', 'last_only')
    .withColumnRenamed('contbr_name_fm', 'first_m')
    .withColumnRenamed('contbr_name_ls', 'last')
    .withColumnRenamed('contbr_addr', 'address')
    .withColumnRenamed('contbr_city', 'city')
    .withColumnRenamed('contbr_zip', 'zip')
    .withColumnRenamed('contbr_st', 'state')
    .select('uid', 'gender', 'first_only', 'last_only', 'first_m', 'last',
            'address', 'city', 'zip', 'state')
)

# count() launches a full Spark job; compute it once and reuse it.
n_cf = cf.count()
# Target roughly 1000 rows per partition. Spark rejects repartition(0), so
# keep at least one partition when the filter matches nothing.
cf = cf.repartition(max(1, ceil(n_cf / 1000)))

print(f"There are {n_cf} FEC observations...")

# L2 voter records for the target state. Years before 2014 read the 2014
# partition; later years read their own partition.
l2_year = '2014' if year < 2014 else str(year)
vf = (
    spark.read.parquet(f"data/l2/year={l2_year}")
    .filter(f"address IS NOT NULL AND state = '{state}'")
    .select('emmid', 'gender', 'zip', 'city', 'address', 'first_only',
            'last_only', 'first_m', 'last_sf', 'state')
    .withColumnRenamed('last_sf', 'last')
    .withColumnRenamed('emmid', 'uid')
)

# count() launches a full Spark job; compute it once and reuse it.
n_vf = vf.count()
# Roughly 1000 rows per partition, never fewer than one (repartition(0) is
# rejected by Spark).
vf = vf.repartition(max(1, ceil(n_vf / 1000)))

print(f"There are {n_vf} L2 observations...")

## SETTINGS

# Predicate appended to every blocking rule: keep only pairs whose numeric
# `gender` values differ by at most ~1. NOTE(review): presumably this drops
# pairs with conflicting known genders under some coded scheme -- confirm
# the gender encoding in both sources.
no_comp = " AND abs(l.gender - r.gender) < 1.001"

# Prediction blocking rules (each with the gender predicate appended):
#   br1 / br2: same zip and same first / last name
#   br3 / br4: same state and city, one name equal and the other sharing its
#              first two characters
#   br5 / br6: same state and address, first / last name sharing its first
#              character
br1, br2, br3, br4, br5, br6 = (
    rule + no_comp
    for rule in (
        "l.zip = r.zip AND l.first_only = r.first_only",
        "l.zip = r.zip AND l.last_only = r.last_only",
        "l.state = r.state AND l.last_only = r.last_only AND substring(l.first_only,1,2) = substring(r.first_only,1,2) AND l.city = r.city",
        "l.state = r.state AND l.first_only = r.first_only AND substring(l.last_only,1,2) = substring(r.last_only,1,2) AND l.city = r.city",
        "l.state = r.state AND l.address = r.address AND substring(l.first_only,1,1) = substring(r.first_only,1,1)",
        "l.state = r.state AND l.address = r.address AND substring(l.last_only,1,1) = substring(r.last_only,1,1)",
    )
)

settings = {
    # Compare records across the two input tables only.
    "link_type": "link_only",

    # The name-based rules (br1-br4) are salted into 4 partitions; the
    # address rules (br5, br6) are not salted.
    "blocking_rules_to_generate_predictions": [
        {"blocking_rule": rule, "salting_partitions": salt}
        for rule, salt in zip((br1, br2, br3, br4, br5, br6), (4, 4, 4, 4, 1, 1))
    ],

    # One Jaro-Winkler comparison per column, in this order:
    # (column, thresholds, include exact-match level, term-frequency adjustments)
    "comparisons": [
        jaro_winkler_at_thresholds(
            column,
            thresholds,
            include_exact_match_level=exact,
            term_frequency_adjustments=tf_adjust,
        )
        for column, thresholds, exact, tf_adjust in (
            ("first_m", [0.94, 0.86, 0.8], True, True),
            ("last", [0.94, 0.86, 0.8], True, True),
            ("address", [0.94, 0.90, 0.86], True, False),
            ("city", [0.94, 0.88], False, True),
        )
    ],

    "retain_matching_columns": True,
    "retain_intermediate_calculation_columns": False,
    "unique_id_column_name": "uid",
    "max_iterations": 40,
}

# Train only on FEC contributors in the target state (the L2 table is already
# restricted to it).
cf_train = cf.filter(f"state = '{state}'")
# Roughly 1000 rows per partition; keep at least one partition because Spark
# rejects repartition(0) (e.g. a state/year with no FEC contributors).
cf_train = cf_train.repartition(max(1, ceil(cf_train.count() / 1000)))

linker = SparkLinker(input_table_or_tables=[cf_train, vf], settings_dict=settings, spark=spark)

# Estimate u-probabilities from random pairs; the sample size is tied to the
# number of L2 rows.
linker.estimate_u_using_random_sampling(target_rows=vf.count())


def _comparison_level(model, column, vector_value):
    """Return the level of comparison `column` whose gamma value is `vector_value`.

    Goes through Splink's private `_settings_obj` API (as this script always
    has), so it may break across Splink versions.
    """
    return (
        model._settings_obj
        ._get_comparison_by_output_column_name(column)
        ._get_comparison_level_by_comparison_vector_value(vector_value)
    )


# Three EM sessions: (training blocking rule, comparisons to deactivate).
# In each session the deactivated comparisons' level 1 is passed as the level
# that reverses the blocking rule. NOTE(review): for these Jaro-Winkler
# comparisons level 1 is presumably the lowest similarity threshold rather
# than an exact match -- confirm this is intended.
# The level lookups happen immediately before each EM call, as before.
for em_blocking_rule, deactivated in (
    ("l.zip = r.zip AND l.first_only = r.first_only", ("city", "first_m")),
    ("l.zip = r.zip AND l.last_only = r.last_only", ("city", "last")),
    ("l.last_only = r.last_only AND l.first_only = r.first_only", ("last", "first_m")),
):
    linker.estimate_parameters_using_expectation_maximisation(
        em_blocking_rule,
        comparisons_to_deactivate=list(deactivated),
        comparison_levels_to_reverse_blocking_rule=[
            _comparison_level(linker, column, 1) for column in deactivated
        ],
        fix_u_probabilities=True,
    )

# Persist the trained model and its match-weights chart for this state/year.
linker.save_settings_to_json(f"logs/{state}_{year}_fec_l2_addr.json", overwrite=True)
save_offline_chart(linker.match_weights_chart().spec, f"logs/{state}_{year}_fec_l2_addr.html", overwrite=True)

# Fresh linker over all FEC contributors in `cf` (not only the training state)
# and the L2 table, configured from the model saved during training.
linker = SparkLinker([cf, vf], spark=spark)

in_path = f"logs/{state}_{year}_fec_l2_addr.json"

# The handle is deliberately not named `f`: `f` is this module's alias for
# pyspark.sql.functions, and `with ... as f` would rebind it.
with open(in_path, "r", encoding="utf-8") as settings_file:
    model_dict = json.load(settings_file)

linker.initialise_settings(model_dict)

# Keep predicted pairs with match probability >= 0.5 where:
#   - both name comparisons are above level 0, or one is null (-1) while the
#     other is at level 4 (NOTE(review): presumably the exact-match level for
#     these comparisons -- confirm), and
#   - the address comparison is above level 0.
df = (
    linker.predict(threshold_match_probability=0.5)
    .as_spark_dataframe()
    .select("match_probability", "uid_l", "uid_r", "gamma_last", "gamma_first_m", "gamma_address")
    .filter("(gamma_last > 0 AND gamma_first_m > 0) OR (gamma_last = -1 AND gamma_first_m = 4) OR (gamma_last = 4 AND gamma_first_m = -1)")
    .filter("gamma_address > 0")
)

print(f"There are {df.count()} predicted matches with posterior >= 0.5")

df.coalesce(20).write.mode("overwrite").parquet(f"data/fec_l2_matches/state={state}/year={year}/type=a/")
